{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n    @Author: King\\n    @Date: 2019.06.25\\n    @Purpose: Graph Convolutional Network\\n    @Introduction:   This is a gentle introduction of using DGL to implement \\n                    Graph Convolutional Networks (Kipf & Welling et al., \\n                    Semi-Supervised Classification with Graph Convolutional Networks). \\n                    We build upon the earlier tutorial on DGLGraph and demonstrate how DGL \\n                    combines graph with deep neural network and learn structural representations.\\n    @Datasets: \\n    @Link : \\n    @Reference : https://docs.dgl.ai/tutorials/models/1_gnn/1_gcn.html\\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "    @Author: King\n",
    "    @Date: 2019.06.25\n",
    "    @Purpose: Graph Convolutional Network\n",
    "    @Introduction:   This is a gentle introduction of using DGL to implement \n",
    "                    Graph Convolutional Networks (Kipf & Welling et al., \n",
    "                    Semi-Supervised Classification with Graph Convolutional Networks). \n",
    "                    We build upon the earlier tutorial on DGLGraph and demonstrate how DGL \n",
    "                    combines graph with deep neural network and learn structural representations.\n",
    "    @Datasets: \n",
    "    @Link : \n",
    "    @Reference : https://docs.dgl.ai/tutorials/models/1_gnn/1_gcn.html\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Understand Graph Attention Network\n",
    "\n",
    "From Graph Convolutional Network (GCN), we learned that combining local graph structure and node-level features yields good performance on node classification task. However, the way GCN aggregates is structure-dependent, which may hurt its generalizability.\n",
    "从图卷积网络（GCN），我们了解到结合局部图结构和节点级特征可以在节点分类任务上产生良好的性能。但是，GCN聚合的方式依赖于结构，这可能会损害其普遍性。\n",
    "\n",
    "\n",
    "One workaround is to simply average over all neighbor node features as in GraphSAGE. Graph Attention Network proposes an alternative way by weighting neighbor features with feature dependent and structure free normalization, in the style of attention\n",
    "一种解决方法是简单地平均所有邻居节点功能，如GraphSAGE中所示。图注意网络通过以注意的方式加权具有特征相关和结构自由规范化的邻居特征来提出另一种方法\n",
    "\n",
    "\n",
    "The goal of this tutorial:\n",
    "\n",
    "- Explain what is Graph Attention Network.\n",
    "- Demonstrate how it can be implemented in DGL.\n",
    "- Understand the attentions learnt.\n",
    "- Introduce to inductive learning.\n",
    "\n",
    "\n",
    "\n",
    "## Introducing Attention to GCN\n",
    "\n",
    "\n",
    "The key difference between GAT and GCN is how the information from the one-hop neighborhood is aggregated.\n",
    "GAT和GCN之间的关键区别在于如何聚合来自一跳邻域的信息。\n",
    "\n",
    "For GCN, a graph convolution operation produces the normalized sum of the node features of neighbors:\n",
    "\n",
    "$$\n",
    "h_{i}^{(l+1)}=\\sigma\\left(\\sum_{j \\in \\mathcal{N}(i)} \\frac{1}{c_{i j}} W^{(l)} h_{j}^{(l)}\\right)\n",
    "$$\n",
    "\n",
    "\n",
    "GAT introduces the attention mechanism as a substitute for the statically normalized convolution operation. (GAT引入了注意机制作为静态归一化卷积运算的替代。) Below are the equations to compute the node embedding $h_{i}^{(l+1)}$ of layer l+1 from the embeddings of layer l:\n",
    "\n",
    "![](img/GAT_1.png)\n",
    "\n",
    "\n",
    "![](img/GAT_2.png)\n",
    "\n",
    "There are other details from the paper, such as dropout and skip connections. For the purpose of simplicity, we omit them in this tutorial and leave the link to the full example at the end for interested readers.\n",
    "\n",
    "In its essence, GAT is just a different aggregation function with attention over features of neighbors, instead of a simple mean aggregation.\n"
   ]
  },
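  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In case the equation images above do not render, the four equations can be written out as follows. This is a transcription sketch that assumes the notation of the GAT paper: $z_{i}^{(l)}$ is the linearly transformed embedding of node $i$, $\\vec{a}^{(l)}$ is a learnable attention weight vector, $\\|$ denotes concatenation, and the numbers (1) to (4) match the equation comments in the `GATLayer` code.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "z_{i}^{(l)} &= W^{(l)} h_{i}^{(l)} && (1) \\\\\n",
    "e_{i j}^{(l)} &= \\operatorname{LeakyReLU}\\left(\\vec{a}^{(l)^{T}}\\left(z_{i}^{(l)} \\,\\|\\, z_{j}^{(l)}\\right)\\right) && (2) \\\\\n",
    "\\alpha_{i j}^{(l)} &= \\frac{\\exp \\left(e_{i j}^{(l)}\\right)}{\\sum_{k \\in \\mathcal{N}(i)} \\exp \\left(e_{i k}^{(l)}\\right)} && (3) \\\\\n",
    "h_{i}^{(l+1)} &= \\sigma\\left(\\sum_{j \\in \\mathcal{N}(i)} \\alpha_{i j}^{(l)} z_{j}^{(l)}\\right) && (4)\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Read in order: (1) is a shared linear transformation of the node features, (2) computes an unnormalized attention score for each pair of neighboring nodes, (3) normalizes the scores over each node's incoming edges with a softmax, and (4) sums the neighbor embeddings scaled by the normalized scores. In the code of this notebook the nonlinearity $\\sigma$ of equation (4) is applied outside the layer, between the two layers of the final model."
   ]
  },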
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GAT in DGL\n",
    "\n",
    "Let’s first have an overall impression about how a GATLayer module is implemented in DGL. Don’t worry, we will break down the four equations above one-by-one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class GATLayer(nn.Module):\n",
    "    def __init__(self, g, in_dim, out_dim):\n",
    "        super(GATLayer, self).__init__()\n",
    "        self.g = g\n",
    "        # equation (1)\n",
    "        self.fc = nn.Linear(in_dim, out_dim, bias=False)\n",
    "        # equation (2)\n",
    "        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)\n",
    "\n",
    "    def edge_attention(self, edges):\n",
    "        # edge UDF for equation (2)\n",
    "        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)\n",
    "        a = self.attn_fc(z2)\n",
    "        return {'e': F.leaky_relu(a)}\n",
    "\n",
    "    def message_func(self, edges):\n",
    "        # message UDF for equation (3) & (4)\n",
    "        return {'z': edges.src['z'], 'e': edges.data['e']}\n",
    "\n",
    "    def reduce_func(self, nodes):\n",
    "        # reduce UDF for equation (3) & (4)\n",
    "        # equation (3)\n",
    "        alpha = F.softmax(nodes.mailbox['e'], dim=1)\n",
    "        # equation (4)\n",
    "        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)\n",
    "        return {'h': h}\n",
    "\n",
    "    def forward(self, h):\n",
    "        # equation (1)\n",
    "        z = self.fc(h)\n",
    "        self.g.ndata['z'] = z\n",
    "        # equation (2)\n",
    "        self.g.apply_edges(self.edge_attention)\n",
    "        # equation (3) & (4)\n",
    "        self.g.update_all(self.message_func, self.reduce_func)\n",
    "        return self.g.ndata.pop('h')"
   ]
  },
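  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked instance of the softmax in `reduce_func` (equation (3)), consider a hypothetical node with three incoming edges whose unnormalized scores are $(1, 2, 3)$. These numbers are made up for illustration and are not taken from the Cora run:\n",
    "\n",
    "$$\n",
    "\\alpha = \\frac{(\\exp 1,\\ \\exp 2,\\ \\exp 3)}{\\exp 1+\\exp 2+\\exp 3} \\approx \\frac{(2.718,\\ 7.389,\\ 20.086)}{30.193} \\approx (0.090,\\ 0.245,\\ 0.665)\n",
    "$$\n",
    "\n",
    "The weights sum to 1 for any number of neighbors, and they depend only on the scores computed from node features, not on node degrees. The output `h` is then the sum of the three neighbor embeddings `z` scaled by these weights."
   ]
  },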
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi-head Attention\n",
    "\n",
    "Analogous to multiple channels in ConvNet, GAT introduces multi-head attention to enrich the model capacity and to stabilize the learning process.(类似于ConvNet中的多个渠道，GAT引入了多头注意力，以丰富模型容量并稳定学习过程。) Each attention head has its own parameters and their outputs can be merged in two ways:\n",
    "\n",
    "![](img/GAT_3.png)"
   ]
  },
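  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In case the image above does not render, the two merge options can be written out as follows. This is a transcription sketch assuming the notation of the GAT paper, where $K$ is the number of heads and $\\alpha_{i j}^{k}$ and $W^{k}$ are the attention weights and projection matrix of head $k$:\n",
    "\n",
    "$$\n",
    "\\text{concatenation: } h_{i}^{(l+1)}=\\|_{k=1}^{K} \\sigma\\left(\\sum_{j \\in \\mathcal{N}(i)} \\alpha_{i j}^{k} W^{k} h_{j}^{(l)}\\right)\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\text{average: } h_{i}^{(l+1)}=\\sigma\\left(\\frac{1}{K} \\sum_{k=1}^{K} \\sum_{j \\in \\mathcal{N}(i)} \\alpha_{i j}^{k} W^{k} h_{j}^{(l)}\\right)\n",
    "$$\n",
    "\n",
    "In `MultiHeadGATLayer`, `merge='cat'` corresponds to the concatenation form and any other value to the average form. The GAT paper suggests concatenation for intermediate layers and averaging for the final layer."
   ]
  },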
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadGATLayer(nn.Module):\n",
    "    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):\n",
    "        super(MultiHeadGATLayer, self).__init__()\n",
    "        self.heads = nn.ModuleList()\n",
    "        for i in range(num_heads):\n",
    "            self.heads.append(GATLayer(g, in_dim, out_dim))\n",
    "        self.merge = merge\n",
    "\n",
    "    def forward(self, h):\n",
    "        head_outs = [attn_head(h) for attn_head in self.heads]\n",
    "        if self.merge == 'cat':\n",
    "            # concat on the output feature dimension (dim=1)\n",
    "            return torch.cat(head_outs, dim=1)\n",
    "        else:\n",
    "            # merge using average\n",
    "            return torch.mean(torch.stack(head_outs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Put everything together\n",
    "\n",
    "Now, we can define a two-layer GAT model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GAT(nn.Module):\n",
    "    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):\n",
    "        super(GAT, self).__init__()\n",
    "        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)\n",
    "        # Be aware that the input dimension is hidden_dim*num_heads since\n",
    "        # multiple head outputs are concatenated together. Also, only\n",
    "        # one attention head in the output layer.\n",
    "        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)\n",
    "\n",
    "    def forward(self, h):\n",
    "        h = self.layer1(h)\n",
    "        h = F.elu(h)\n",
    "        h = self.layer2(h)\n",
    "        return h"
   ]
  },
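  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check of the dimensions, the parameter count of this model can be worked out by hand. The sketch assumes Cora's 1433-dimensional input features (a property of the dataset that this notebook does not state) and the hyperparameters used in the training cell (`hidden_dim=8`, `num_heads=2`, `out_dim=7`). Each `GATLayer` holds an `in_dim` by `out_dim` projection and an attention vector of length `2 * out_dim`, both without bias:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{layer1: } & 2 \\times (1433 \\times 8+2 \\times 8)=2 \\times 11480=22960 \\\\\n",
    "\\text{layer2: } & 1 \\times (16 \\times 7+2 \\times 7)=126 \\\\\n",
    "\\text{total: } & 22960+126=23086\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The 16 in layer2 is `hidden_dim * num_heads`, because the two head outputs of layer1 are concatenated before they reach the output layer."
   ]
  },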
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We then load the cora dataset using DGL’s built-in data module.\n",
    "\n",
    "from dgl import DGLGraph\n",
    "from dgl.data import citation_graph as citegrh\n",
    "\n",
    "def load_cora_data():\n",
    "    data = citegrh.load_cora()\n",
    "    features = torch.FloatTensor(data.features)\n",
    "    labels = torch.LongTensor(data.labels)\n",
    "    mask = torch.ByteTensor(data.train_mask)\n",
    "    g = data.graph\n",
    "    # add self loop\n",
    "    g.remove_edges_from(g.selfloop_edges())\n",
    "    g = DGLGraph(g)\n",
    "    g.add_edges(g.nodes(), g.nodes())\n",
    "    return g, features, labels, mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\progrom\\python\\python\\python3\\lib\\site-packages\\dgl\\base.py:18: UserWarning: Initializer is not set. Use zero initializer instead. To suppress this warning, use `set_initializer` to explicitly specify which initializer to use.\n",
      "  warnings.warn(msg)\n",
      "D:\\progrom\\python\\python\\python3\\lib\\site-packages\\numpy\\core\\fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.\n",
      "  out=out, **kwargs)\n",
      "D:\\progrom\\python\\python\\python3\\lib\\site-packages\\numpy\\core\\_methods.py:85: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  ret = ret.dtype.type(ret / rcount)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00000 | Loss 1.9459 | Time(s) nan\n",
      "Epoch 00001 | Loss 1.9451 | Time(s) nan\n",
      "Epoch 00002 | Loss 1.9443 | Time(s) nan\n",
      "Epoch 00003 | Loss 1.9434 | Time(s) 0.2730\n",
      "Epoch 00004 | Loss 1.9426 | Time(s) 0.2697\n",
      "Epoch 00005 | Loss 1.9418 | Time(s) 0.2794\n",
      "Epoch 00006 | Loss 1.9410 | Time(s) 0.2815\n",
      "Epoch 00007 | Loss 1.9401 | Time(s) 0.2837\n",
      "Epoch 00008 | Loss 1.9392 | Time(s) 0.2826\n",
      "Epoch 00009 | Loss 1.9384 | Time(s) 0.2817\n",
      "Epoch 00010 | Loss 1.9375 | Time(s) 0.2800\n",
      "Epoch 00011 | Loss 1.9366 | Time(s) 0.2789\n",
      "Epoch 00012 | Loss 1.9356 | Time(s) 0.2777\n",
      "Epoch 00013 | Loss 1.9347 | Time(s) 0.2774\n",
      "Epoch 00014 | Loss 1.9338 | Time(s) 0.2772\n",
      "Epoch 00015 | Loss 1.9328 | Time(s) 0.2767\n",
      "Epoch 00016 | Loss 1.9318 | Time(s) 0.2760\n",
      "Epoch 00017 | Loss 1.9308 | Time(s) 0.2759\n",
      "Epoch 00018 | Loss 1.9298 | Time(s) 0.2753\n",
      "Epoch 00019 | Loss 1.9288 | Time(s) 0.2752\n",
      "Epoch 00020 | Loss 1.9278 | Time(s) 0.2753\n",
      "Epoch 00021 | Loss 1.9267 | Time(s) 0.2747\n",
      "Epoch 00022 | Loss 1.9256 | Time(s) 0.2743\n",
      "Epoch 00023 | Loss 1.9245 | Time(s) 0.2745\n",
      "Epoch 00024 | Loss 1.9234 | Time(s) 0.2744\n",
      "Epoch 00025 | Loss 1.9223 | Time(s) 0.2742\n",
      "Epoch 00026 | Loss 1.9211 | Time(s) 0.2743\n",
      "Epoch 00027 | Loss 1.9200 | Time(s) 0.2741\n",
      "Epoch 00028 | Loss 1.9188 | Time(s) 0.2738\n",
      "Epoch 00029 | Loss 1.9176 | Time(s) 0.2736\n"
     ]
    }
   ],
   "source": [
    "# The training loop is exactly the same as in the GCN tutorial.\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "\n",
    "g, features, labels, mask = load_cora_data()\n",
    "\n",
    "# create the model, 2 heads, each head has hidden size 8\n",
    "net = GAT(g,\n",
    "          in_dim=features.size()[1],\n",
    "          hidden_dim=8,\n",
    "          out_dim=7,\n",
    "          num_heads=2)\n",
    "\n",
    "# create optimizer\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
    "\n",
    "# main loop\n",
    "dur = []\n",
    "for epoch in range(30):\n",
    "    if epoch >= 3:\n",
    "        t0 = time.time()\n",
    "\n",
    "    logits = net(features)\n",
    "    logp = F.log_softmax(logits, 1)\n",
    "    loss = F.nll_loss(logp[mask], labels[mask])\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch >= 3:\n",
    "        dur.append(time.time() - t0)\n",
    "\n",
    "    print(\"Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}\".format(\n",
    "        epoch, loss.item(), np.mean(dur)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
